Skip to content

[WIP][MLPerf 6.1][DLRMv4] Adding DLRMv4 HSTU#889

Draft
chriscai-amd wants to merge 121 commits into
mlcommons:masterfrom
chriscai-amd:chcai/dlrmv4-idfix
Draft

[WIP][MLPerf 6.1][DLRMv4] Adding DLRMv4 HSTU#889
chriscai-amd wants to merge 121 commits into
mlcommons:masterfrom
chriscai-amd:chcai/dlrmv4-idfix

Conversation

@chriscai-amd

@chriscai-amd chriscai-amd commented Jun 25, 2026

Copy link
Copy Markdown

Adds a new DLRMv4 MLPerf Training reference benchmark: an HSTU
(Hierarchical Sequential Transduction Units) generative recommender trained on
the public Yambda-5b dataset, vendored under recommendation_v4/ as a
sibling of the existing recommendation_v2/torchrec_dlrm.

What this adds

  • Model / framework: HSTU (generative_recommenders dlrm_v3 path) on
    TorchRec DistributedModelParallel, bf16 training, single listen_plus
    task. Architecture, history_length, hbm_cap, and embedding dims are
    gin-tunable with env overrides.
  • Dataset pipeline: preprocess_public_data.py downloads + preprocesses
    Yambda (50m / 500m / 5b) from HuggingFace (temporal GTS split, session
    segmentation, item popularity); DLRM_DATA_PATH env override.
  • Streaming (temporal-order) training: window-by-window train in strict time order,
    no future leakage. Eval is a strictly-future held-out window — either the next window (T → T+1)
    or a fixed holdout window/dataset (eval_holdout_ts, optional train/eval split).
    Persistent-loader + double-buffer + eval-prefetch hide window-reset overhead;
    crash-resumable checkpoints with step/time/window cadences and bit-equal RNG replay.
  • Hardware support: AMD MI350X (gfx950) Triton HSTU kernels enabled
    (TMA early-out, jagged multirow routing, separated-RNG LN-dropout, autotune
    pinning) plus NVIDIA B200 (sm_100). Multi-node SLURM launcher with GPUDirect
    RDMA.
  • MLPerf compliance: mllog integration centralized in
    mlperf_logging_utils.py (resume-aware INIT/RUN/EVAL markers, AUC convergence
    target 0.80275, AMD/MI355X submission identity); mlperf_logging pinned to
    6.0.0-rc6 in Dockerfiles + requirements.
  • Reference packaging: download_dataset.sh / verify_dataset.sh /
    run_and_time.sh wrappers, MLPerf-spec README (summaries, model+paper,
    hyperparameter table, quality target, eval frequency), frozen
    requirements.txt, and an RCP placeholder.
  • Docs: reproducible MI350X + B200 training recipes (container images,
    dependency versions, perf ladder) and perf/multi-node notes.

chriscai-amd and others added 30 commits May 29, 2026 16:21
…7e51c)

Vendored snapshot of chriscai-amd/generative-recommenders branch chcai/dlrmv4
(HEAD d97e51c) as a sibling of recommendation_v2/torchrec_dlrm. The Python
package generative_recommenders keeps its original name so all imports work
unchanged from the new location.

- recommendation_v4/generative_recommenders/: dlrm_v3, modules, ops, research, tests
- recommendation_v4/configs/: research HSTU gins
- recommendation_v4/scripts/launch_smoke_8gpu.sh: sanitized 8-GPU yambda-5b launcher
  (resolves package root from script path; AMD env defaults; pip_local override)
- recommendation_v4/{setup.py,requirements.txt,main.py,...}: upstream entry points
- .gitmodules: cutlass registered at parent repo level

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Four fixes unlocking the HSTU_HAMMER_KERNEL=TRITON path on MI350X:

1. triton_hstu_attention.py _should_enable_tma(): add HIP early-out.
   torch.cuda.get_device_capability() on gfx950 returns (9, 5) which would
   pass the major==9 Hopper check and trick the kernel into the TMA path,
   producing kernels that don't compile on ROCm.

2. triton_hstu_attention.py _get_fw_configs(): hoist the USE_TLX/NUM_BUFFERS/
   NUM_MMA_WARPS_PER_GROUP/NUM_MMA_GROUPS defaults loop out of the CUDA-only
   else: branch. The _hstu_attn_fwd signature requires these constexprs
   regardless of backend; missing them on HIP triggered TypeError:
   dynamic_func() missing N required positional arguments at autotune.
   Also gate the H100 TLX configs append on `not torch.version.hip`.

3. triton_jagged_tensors.py concat/split dispatch: route AMD/ROCm through
   *_2D_jagged_multirow instead of the basic _concat_2D_jagged /
   _split_2D_jagged kernels. The basic kernels fail PassManager::run at
   make_ttgir (TritonAMDGPUCanonicalizePointers pass) on ROCm; multirow
   compiles fine. NVIDIA non-Blackwell paths (H100/A100) are unchanged.

4. triton_jagged_tensors.py _Concat2DJaggedFunction.backward: replace the
   raw _split_2D_jagged[grid] call with _triton_split_2D_jagged_internal
   so the backward pass benefits from the same AMD multirow routing as
   the forward.

Verified end-to-end on 8x MI350X: yambda-5b bs=32 seq=4k at 782 global_sps
vs PYTORCH backend 547 sps -- 1.43x throughput, 75% peak VRAM vs 92%.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The attribute is absent in some Triton builds (e.g. nvcr.io/nvidia/pytorch:26.01-py3),
causing import-time AttributeError before any training step runs. Use getattr with
a False default so _use_meta_ws() gracefully reports disabled on those builds.
Three small changes so you can sweep model size and per-sample sequence
length from a gin file without editing configs.py.

configs.py:
  - get_hstu_configs is now @gin.configurable. Accepts optional overrides
    for max_seq_len, max_num_candidates, hstu_embedding_table_dim,
    hstu_transducer_embedding_dim, hstu_num_heads, hstu_attn_num_layers,
    hstu_attn_linear_dim, hstu_attn_qk_dim, hstu_input_dropout_ratio,
    hstu_linear_dropout_rate. Per-dataset defaults still apply unless
    explicitly overridden in gin.
  - get_embedding_table_config is now @gin.configurable with an
    embedding_dim override that uniformly sets the dim for all tables
    of the chosen dataset.
  - Drop the YAMBDA_EMBEDDING_DIM constant (was a duplicate of
    HSTU_EMBEDDING_DIM=512). Yambda branch now uses HSTU_EMBEDDING_DIM
    directly. Add a comment noting the model+table dim must stay aligned
    when overriding either via gin.

utils.py:
  - get_dataset accepts an optional history_length kwarg that wins over
    the yambda dataset's hardcoded default of 4096. Caches are still
    keyed on disk under hstu_cache_L<N>/ so switching L between previously
    built values is free.

train/gin/yambda_5b.gin:
  - Pin history_length=2048 and max_seq_len=2048 for the seq-2k smoke
    config. Both lines have inline comments explaining the +9 overhead
    (uid + 7 cross + 1 candidate) so total per-sample seq is ~2046,
    within the 2048 budget.

Verified: default codepath unchanged, gin overrides apply consistently
to both get_hstu_configs (model) and get_embedding_table_config (tables).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
make_optimizer_and_shard now accepts hbm_cap_gb (default 260, the MI350X
value) via @gin.configurable. The yambda gin pins the same default so
sweeps just change the number in the gin file instead of editing utils.py.

ddr_cap dropped from 32 GiB to 0: with all 11 yambda 5b embedding tables
fitting on 8x MI350X HBM, allowing host DRAM offload only invites the
planner to pick slower per-lookup-PCIe-traffic plans.

Verified gin binding flows through to the Topology: a probe with
hbm_cap_gb=100 produced Topology(hbm_cap=107374182400) and the planner
correctly raised insufficient-storage error at that tightness.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
preprocess_public_data.py:
  - Add DLRMYambdaProcessor: downloads Yambda multi_event + catalog
    metadata from the yandex/yambda HuggingFace repo, then runs a
    temporal split (300 train days / 30 min gap / 1 test day),
    builds per-user sessions (1800s inactivity threshold), and
    writes the layout DLRMv3YambdaDataset expects:
      <data-path>/raw/<size>/multi_event.parquet
      <data-path>/shared_metadata/{artist,album,embeddings}.parquet
      <data-path>/processed_<size>/{train_sessions,test_events,
                                    session_index}.parquet
      <data-path>/processed_<size>/item_popularity.npy
      <data-path>/processed_<size>/split_meta.json
  - 5b variant uses chunked polars load (10M rows/chunk) to keep
    peak RAM under control (single-shot read of the 50 GB parquet
    OOMs ~150 GB systems).
  - SUPPORTED_DATASETS adds yambda-50m, yambda-500m, yambda-5b.
  - main() takes --data-path for custom output root.
  - Verified end-to-end: 50m run completes in ~2 min, 5b in ~53 min
    (download dominates), output is byte-compatible with the dataset
    cache builder; TRITON training reaches steady state on the
    fresh data at 2050 sps.

utils.py:
  - Add env_path(key, default) @gin.configurable helper. Used as a
    gin macro so any string-valued binding can be overridden by an
    env var without editing the gin file.

train/gin/yambda_5b.gin:
  - Declare DATA_PATH = @env_path() macro with key="DLRM_DATA_PATH"
    and default="/apps/chcai/dlrm_data". Both new_path_prefix
    bindings (make_train_test_dataloaders and get_dataset) now
    consume %DATA_PATH. Setting DLRM_DATA_PATH=/some/path at run
    time redirects the dataset without a gin edit.

datasets/yambda.py:
  - Strip stale references to upstream-internal preprocessing in
    docstrings/comments; point at preprocess_public_data.py instead.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Every rank's first CUDA context was landing on GPU 0 (the default
device), so NCCL bound its communicators there before set_device
switched to the correct GPU. This leaked allocations on GPU 0 across
all 8 ranks and caused spurious OOMs during embedding-table init at
high HBM caps. Moving set_device above init_process_group and passing
device_id ensures each rank's NCCL state is created on its own GPU.
dlrm_v3/utils.py:
  - Replace the hardcoded manifold:// URL in _on_trace_ready_fn with a
    local trace_dir (default /tmp/dlrm_v3_traces). Filename now follows
    trace_step{step}_rank{rank}.json so per-rank captures don't collide.
  - Add _multi_window_schedule helper: a torch.profiler schedule that
    fires around each step in trace_steps=[...] (warmup before, active
    after, RECORD_AND_SAVE at the last active step). Lets one run
    capture multiple windows (e.g. early-step + steady-state) without
    re-running.
  - Make Profiler @gin.configurable. New knobs: trace_dir, trace_steps,
    wait, warmup, repeat, record_shapes, profile_memory, with_stack,
    with_flops, with_modules. Defaults preserve the prior single-window
    behavior (wait=10, warmup=20, active=50, repeat=1) so existing
    callers are unaffected.
  - Add run_results_dir(run_name) gin macro: resolves to
    <recommendation_v4>/results/<run_name>/. Used as the canonical
    output prefix for traces (and any future per-run artifacts).
    recommendation_v4/ is bind-mounted into the training container, so
    files written through this helper persist on the host.

train/gin/yambda_5b.gin:
  - Wire RUN_NAME env override -> run_results_dir(run_name=%RUN_NAME)
    -> Profiler.trace_dir. Sets trace_steps=[52], warmup=5, active=5
    (capture the 5-step window 52-56 on every rank).
  - Toggle train_eval_loop.output_trace = True so the profiler actually
    instantiates.

.gitignore:
  - Add results/ alongside the existing tmp/exps/ckpts/ runtime
    directories so per-run trace dumps don't show up in git status.

Verified: 8x MI350X TRITON yambda-5b run at bs=32 seq=2k drops
8 well-formed trace_step62_rank{0..7}.json files (~37 MB each) into
recommendation_v4/results/default/; visible on the host immediately.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…mes, trim_warmup

dlrm_v3/utils.py
  * Add run_results_dir(run_name) gin macro (resolves to
    <recommendation_v4>/results/<run_name>/) so trace artifacts persist on
    the host via the bind-mount.
  * Add _trim_warmup_from_trace post-processor: dedupes ProfilerStep spans
    by name first, then keeps only the last N unique steps' worth of
    events. Drops WARMUP-phase events that torch.profiler otherwise
    includes in the chrome trace.
  * Add trim_warmup kwarg (default True) on Profiler; auto-invokes the
    trimmer with N=active so the exported file matches the user-requested
    active window.
  * Filename now uses trace_steps[i] (the user-requested step) as the
    {step} label when multi-window mode is in use, instead of
    torch.profiler's internal step_num (which is off by ~warmup+active
    from the schedule trigger and confused everyone).

train/utils.py
  * Drop hardcoded `active=10` from the four `Profiler(rank, active=10)`
    call sites in train_loop / train_eval_loop. Positional args block
    gin overrides; once removed, Profiler.active in gin (default 50) and
    user gin bindings actually take effect.

train/gin/yambda_5b.gin
  * Fix env_path scoping collision: both DATA_PATH and RUN_NAME used the
    unscoped @env_path() configurable, which made the second binding's
    `env_path.key = "RUN_NAME"` overwrite the first's
    `env_path.key = "DLRM_DATA_PATH"`. Both names then resolved via the
    same env var (whichever was last), pointing DATA_PATH at trace_run2/
    and breaking dataset loads.
    Fixed by giving each call site its own scope: @data/env_path() and
    @run/env_path(), each with independent .key/.default bindings.
  * Set Profiler.trace_steps=[52], warmup=1, active=5; let trim_warmup
    default to True so the exported trace contains exactly 5 active
    ProfilerStep events.

Verified end-to-end:
  - Run with RUN_NAME=trace_run2 writes results/trace_run2/trace_step52_
    rank{0..7}.json (~19 MB each), step labels match trace_steps gin.
  - Triton cache persisted across runs: cold start ~6 min -> warm start
    ~2 min for autotune-to-first-step.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2048 was chosen for "round number near max_seq_len" but it slightly
overflows the per-sample budget: 3 * (2048//3) + 9 = 2055 > 2048, so
the dataset truncates ~7 UIH events to fit. 2039 makes the math exact
(3 * 679 + 9 = 2046 ≤ 2048) so no truncation.

Comment block expanded to document:
  - The 3-pool gather semantic (L//3 events per pool, interleaved
    chronologically).
  - The like-pool under-fill observation: like events are only 1.9%
    of yambda corpus and max user lifetime is ~28k events, so the
    like pool fills to ~105 events per anchor on average (not 679).
    TRITON's jagged attention skips the unfilled slots, so under-fill
    costs sequence budget but not GPU compute.

No code change. Cache for L=2039 already built and reused.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ol gather

Documents the fork's scope (yambda-5b on HSTU dlrm_v3 path), per-pool gather
strategy with effective fill table, and dataset statistics. Sections indexed
1–5 for navigation.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Adds env_int gin macro (companion to env_path) and wires
make_optimizer_and_shard.hbm_cap_gb through it so the per-rank HBM
ceiling can be tuned without editing the gin file.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Document the container image, dependency versions (native NGC torch 2.10,
triton 3.6, source-built fbgemm_gpu, torchrec 1.4.0, polars-u64-idx), gin
training configuration, and env vars needed to reproduce the 8x B200 run.
Adds three knobs, all driven from the gin file:
- make_model.bf16_training: enable bf16 autocast for the DlrmHSTU model.
- env_int macro: lets numeric gin values come from env vars (used by the
  existing hbm_cap_gb binding).
- apply_env_bootstrap.TRITON_FULL_AUTOTUNE: when False (default), three
  layer-norm/jagged triton kernels are pinned to a single Config so cold
  starts land at the same steady-state deterministically. When True, the
  full autotune search runs again — use this when changing shape, GPU,
  or triton/torch version, then re-pin from the discovered winners.

train_ranker._main_func now parses gin in two phases (skip_unknown=True
early, full pass after the heavy imports) so the bootstrap env var is set
BEFORE the triton kernel modules evaluate their @triton.autotune
decorators at module load time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Mirrors the B200 layout with MI350X (gfx950, ROCm 7.2.1) specifics:
container image (rocm/primus:v26.3), fbgemm_gpu rebuild requirement (HEAD
nightly_rocm-2026.6.1 for ~30% step-time win over the shipped 2026.5.14),
the gin-driven TRITON_FULL_AUTOTUNE knob, and the measured perf ladder
from fp32/PYTORCH baseline (~28 d/epoch) down to the pinned bf16/TRITON
fast equilibrium (~7.6 d/epoch).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Merges per-rank chrome traces (results/<run>/trace_step{N}_rank{R}.json)
into a single Perfetto-loadable file, remapping pid/flow ids so
cross-rank events land on distinct tracks instead of collapsing onto one.
Used to produce the bf16 + pinned-autotune step-52 trace
(results/verify_rename/trace_step52.json.gz).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Refresh the B200 dependency versions to the latest validated stack
(torch 2.12.0a0 / CUDA 13.2, fbgemm_gpu built for sm_100+CUDA 13.2,
CUPTI 13.2), note 26.01 as an equivalent alternative, and record the
TRITON_FULL_AUTOTUNE=True setting for B200.
…iver)

Point fbgemm at the latest validated source commit (10b77573, 2026-06-01),
record the tested torchrec 1.7.0.dev nightly (1.4.0 stable fallback),
clarify the fbgemm wheel version string is the build date, and correct the
host/forward-compat driver CUDA versions (13.0 host / 595.58.03 compat).
After upgrading to torch 2.12 / torchrec 1.7 (B200-aligned), the pinned
configs from the torch 2.10 stack stopped landing on the fast equilibrium
because the torchrec 1.7 code path invokes these kernels at different
shape keys. Re-captured winners via a fresh autotune run and updated the
pin sites:

- _weighted_layer_norm_bwd_dx: BLOCK_N 8 -> 1 (num_warps 1 unchanged)
- split_2D_jagged_multirow:    BLOCK_N 1 / num_warps 2 -> BLOCK_N 8 / num_warps 1
- _layer_norm_bwd_dwdb:        BLOCK_N 128, num_warps 8 (unchanged - same winner on both stacks)

Verified: 3 consecutive checkpoints (steps 151/201/251) at 52.75-53.36 ms
deterministic on the new stack. Same equilibrium band as the torch 2.10
stack (51.5-53.0 ms).

Also adds a Stack B section to docs/training_recipe.md (MI350X) documenting
the torch 2.12 swap recipe (torch + torchvision + torchaudio + fbgemm
rebuild + torchrec git tag) so the MI350X recipe is dependency-aligned with
the B200 path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Bumps the Stack B (torch 2.12 / torchrec 1.7) section to:
- fbgemm commit 10b77573 (same SHA as the B200 path) instead of 1509423
  (one cosmetic commit behind). Wheel rename 2026.6.1 -> 2026.6.2.
- Note that Stack A and Stack B use different pinned triton configs
  (already merged) and explain why (torchrec 1.7 invokes the kernels at
  different shape keys).
- Caveat: HSTU_HAMMER_KERNEL=PYTORCH fallback regresses to ~169 ms on
  Stack B (vs 107 ms on Stack A). TRITON is unaffected and remains the
  default; this only matters for PYTORCH-backend debugging.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Collapses the two-stack MI350X section into one canonical dependency
table: torch 2.12 / torchrec 1.7 / fbgemm @ 10b77573 — the same SHAs as
the B200 path. The image-native torch 2.10 / torchrec 1.4 / fbgemm
2026.5.14 path still works for development but the recipe doc now
documents the validated production stack only.

PYTORCH-backend caveat preserved.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Not relevant — TRITON is the documented default backend.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ecipe

Embedding sizes match the true entity counts in yambda-5b:
  item_id   9_390_000 -> 9_390_624
  artist_id 1_290_000 -> 1_293_395
  album_id  3_370_000 -> 3_367_692
  uid       1_000_000 -> 1_000_001
This eliminates the recurring "EmbeddingBoundsCheck ... Setting idx to
zero" warnings at training time.

Gin default raised to batch_size=1024 / eval_batch_size=1024. Measured
steady-state on the torch 2.12 + torchrec 1.7 + fbgemm HEAD stack with
TRITON HSTU + pinned triton configs: ~635 ms/step, ~12.9K sps, ~2.92
days/epoch vs ~7.6 days at bs=32. bs=2048 is feasible but only +3%
throughput at much higher autotune cost, so bs=1024 is the sweet spot.

Triton autotune pin for _weighted_layer_norm_bwd_dx now ships TWO
configs in the pinned list — BLOCK_N=1 (bs=32 winner) and BLOCK_N=8
(bs=1024 winner). Triton's autotune key=[BLOCK_D] dispatches the right
one per shape in <5 sec on cold start (vs ~30 sec from the full pool).
The other two pinned kernels (_layer_norm_bwd_dwdb, split_2D_jagged_multirow)
have identical winners at bs=32 and bs=1024 so they stay single-config.

Training-recipe doc drops the batch_size rows from both MI350X and B200
config tables — the recipe is intentionally batch-size-agnostic now that
the pin set covers a range.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Enable the multi-row, separated-RNG _ln_mul_dropout path on AMD MI350 (gfx950),
previously Blackwell-only. Batches rows per program and reuses a precomputed
dropout mask in the backward instead of one-program-per-row fused RNG; +5.6%
end-to-end (-> 14,222 global sps) at bs=1024 on yambda-5b.

- ops/utils.py: add is_amd_mi350() + use_separated_rng_ln_mul_dropout() gate.
- ops/triton/triton_hstu_linear.py: dispatch the fwd LN-dropout to the
  separated-RNG path via the new gate.
- ops/triton/triton_hstu_attention.py: pin fast nonkdim:16 fwd/persistent/bwd
  configs via pinned_or_full (TRITON_FULL_AUTOTUNE=1 still bypasses). Multi-config
  lists with an inline "add a new batch size" guide.
- scripts/launch_smoke_8gpu.sh: GPU clock sanity guard - log perf level + sclk,
  auto-restore 'auto' if a perf_determinism/manual/low lock is found (a half-clock
  lock uniformly slowed every Triton kernel ~1.9x and masked perf changes).
- docs/perf_opt.md: document the LN-dropout fix and the clock-lock caveat.

Co-authored-by: Cursor <cursoragent@cursor.com>
…ernel

Add an opt-in TrainPipelineSparseDist path that overlaps the embedding
input-distribution all-to-all with dense fwd/bwd. To make the embedding
collection pipelineable, the merged sparse KJT is now pre-built in the
dataloader (Samples.merged_sparse_features) and the model consumes it via a
_pipeline_mode forward that takes the batch as a single arg, so TorchRec's
tracer resolves the lookup input as a plain getattr off the batch.

- dataset.py: Samples.merged_sparse_features + merge_uih_candidate_kjts, built
  in collate_fn; wired into to()/record_stream()/pin_memory().
- dlrm_hstu.py: _pipeline_mode flag; forward unpacks the batch and preprocess
  accepts the prebuilt merged KJT (falls back to building it when absent).
- utils.py: _PipelineModelWrapper, build_train_pipeline, train_eval_loop
  use_pipeline branch + eval batch-arg; seed all RNGs in setup() for
  reproducible weight init.
- gin/launch: make_model.hammer_kernel selects TRITON vs PYTORCH (env override
  still honored); launch script defers to the gin default. use_pipeline
  defaults to False.

Validated on MI350/ROCm 8-GPU: embedding collection is pipelined (input-dist
a2a moves to hidden); model quality and throughput match the sequential path
(seeded A/B). The exposed embedding-output a2a still dominates the step, so
throughput is unchanged — pipelining is quality- and perf-neutral here.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add a forward-in-time streaming path: slice the timeline into fixed-duration
windows (default 1 day), train window T then eval window T+1, enforcing no
future leakage (across-window + causal-history guarantees). Make it the
default mode in launch_smoke_8gpu.sh.

Window-reset overhead is hidden via a persistent worker pool + double
buffering (next window's index mask and first-batch prefetch overlap compute
on a background thread) and eval-window prefetch one window ahead, dropping
train/eval first-batch waits to ~1-3ms with no steady-state regression.
Window selection uses a lazily-built, mmap'd anchor-timestamp cache so the
default non-streaming path is unaffected.

Also harden trace export (best-effort: IO/permission failures warn instead of
crashing training) now that streaming enables output_trace by default, and
document the path + knobs in the README.

Co-authored-by: Cursor <cursoragent@cursor.com>
save_dmp_checkpoint.path now resolves from $CKPT_PATH and defaults to empty,
so checkpoints (a full DMP is ~100s of GB, and the streaming loop always saves
the final window) are off unless explicitly enabled. Also drop the stale
training-recipe sentence claiming native torch is kept — it contradicts the
dependency table, which replaces torch and keeps only the image's triton.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add in-process trace postprocessing in the profiler on_trace_ready
callback to fix two ROCm/roctracer rendering artifacts that make MI350X
traces look wrong in Perfetto (the timing is correct, only the layout):

- _normalize_profilerstep_layout: collapse the fragmented GPU-side
  ProfilerStep#N spans (roctracer splits a step across the HIP null +
  compute streams) into one full-width span per step on the busiest
  compute stream, matching the CUDA look.
- _deoverlap_gpu_slices: pull back sub-us kernel end timestamps so
  back-to-back kernels don't touch/overlap; Perfetto otherwise nests the
  later (long) kernel inside the tiny epilogue and clips it to zero width,
  hiding kernels like _hstu_attn_bwd. Leaves a ~1ns gap (exact end==start
  is just as fatal as an overlap) and leaves real nesting untouched.

Both passes are gated behind _is_rocm() (torch.version.hip) so they are
complete no-ops on CUDA/B200, which don't have these artifacts. All
best-effort: failures degrade to a warning and never crash training.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add _deoverlap_gpu_annotations to the trace-export postprocessing, the
annotation-boundary analog of the kernel de-overlap. Kineto projects the
forward/backward phase annotations (## user_forward ##, ## item_forward ##,
## stu_* ##, ...) onto the GPU stream as a chain of end-to-end siblings.
The absolute step timestamps are ~5.4e12 us, where a float64's quantum is
~1 ns, so a sibling boundary that should be coincident lands a few ns off;
when the earlier sibling ends at/after the next one's start, Perfetto nests
and clips the next span to a sliver -- e.g. the 100+ ms ## user_forward ##
vanishes on some ranks/steps purely by rounding luck.

Since annotations form a real nesting hierarchy (user_forward contains the
stu_* spans and their kernels), this walks the per-track slice stack and
only snaps a slice back when the next slice extends beyond it (siblings,
not parent/child), guarding against trimming into a span's own descendants.
It also snaps kernel tails that straddle an annotation boundary. Gated by
_is_rocm() (no-op on B200/CUDA) and best-effort like the other passes.

Verified end-to-end on an 8-rank MI350X run: ## user_forward ## renders
40/40 (was 9/40), total clipped annotations 1352 -> ~5.

Co-authored-by: Cursor <cursoragent@cursor.com>
Make streaming-train-eval crash-resumable and add general checkpoint
cadence controls:

- Atomic checkpoint saves (.tmp dir + rename), keep_last_n pruning, and
  swap-aside .old overwrite so a save can safely replace an existing
  train_ts dir; stale .tmp/.old swept on the next save.
- Per-rank RNG snapshot/restore for bit-equal dropout replay on resume;
  auto-latest-subdir resolution + (train_ts, batch_idx_in_window) resume
  hint so a run re-enters a partial window and skips already-trained
  batches exact-once.
- Three independent in-window checkpoint cadences via a pure, testable
  decision helper: per-window batch count, monotonic global step
  (e.g. every 1000 steps), and wall-clock interval (e.g. hourly,
  rank-0-decided + broadcast to keep the save barrier in lockstep).
- gin/env bindings for all cadences + a test-only die_at_step hook.

Tests: checkpoint_cadence_test.py (cadence precedence/triggers) and an
end-to-end baseline/interrupt/resume harness (streaming_resume_test.{sh,py})
that gates on functional invariants (RNG restored, correct resumed step,
atomic save, keep_last_n) plus a loose trajectory-closeness bound.

Co-authored-by: Cursor <cursoragent@cursor.com>
suachong and others added 14 commits June 24, 2026 19:56
Revert recommendation_v4/.gitignore to base. Local run artifacts and ad-hoc
analysis files are kept out of the repo via .git/info/exclude (local,
uncommitted) instead, removing one file from the PR diff.

Co-authored-by: Cursor <cursoragent@cursor.com>
Move the MLPerf event stream into mlperf_logging_utils.py: a MLPerfRunTracker
state machine owns the block/eval/run markers, progress metadata, and the
convergence decision (replacing ~145 lines of closures in
streaming_train_eval_loop), and MLPerfLogger.log_run_start emits submission
info + hyperparameters + INIT_STOP/RUN_START (collapsing the inline block in
train_ranker).

Convergence/EVAL_ACCURACY is fixed to per-window AUC: drop the
eval_accuracy_auc_mode knob (gin + loop param + launch_slurm passthrough).

Submission identity: SUBMISSION_ORG defaults to AMD, SUBMISSION_PLATFORM to
MI355X (was the org name — a bug), both overridable via
$MLPERF_SUBMISSION_PLATFORM.

Co-authored-by: Cursor <cursoragent@cursor.com>
The lifetime cumulative AUC always uses the exact binned backend now. Remove
the TRAIN_LIFETIME_AUC_MODE / EVAL_LIFETIME_AUC_MODE env overrides (and the
capped-only LIFETIME_AUC_WINDOW knob, now dead) from gin and launch_slurm.

Co-authored-by: Cursor <cursoragent@cursor.com>
Reduce the upstream launch_slurm.sh diff to just what MLPerf logging needs:
SCRATCH/REPO_MOUNT/DATA_MOUNT path portability (so outputs/log land off the
hardcoded /home/chcai,/apps/chcai) and the MLPerf env wiring (MLPERF_LOG_PATH,
AUC_THRESHOLD, MLPERF_LOGGING, MLPERF_SUBMISSION_PLATFORM,
MLPERF_TRAIN_LOSS_LOG_FREQ). Reverted the unrelated baseline changes (NCCL
GDR/IFNAME defaults, SMOKE/frozen run-shape, chmod/WORKER_TEE, HISTORY_STRATEGY,
lifetime-AUC passthroughs) to Chris' base.

Preserve the full kitchen-sink launcher as launch_slurm_suachong.sh for
personal multi-node use (self-reinvoke paths repointed to itself).

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Persist mlperf_run_started in the checkpoint so a resume relaunch
continues the SAME MLPerf run instead of re-emitting INIT_START/RUN_START
(compliance requires EXACTLY_ONE). Cold-vs-resume is detected from the
on-disk checkpoint before setup(); the log is truncated on a cold start
and appended on a resume so the single run's event stream accumulates
into one file. Legacy/cold checkpoints default the flag to False.

Co-authored-by: Cursor <cursoragent@cursor.com>
Default SEED back to 1 for a fixed, reproducible weight init out of the
box ($SEED=-1 still draws a fresh random seed per run).

Default AUC_THRESHOLD to 1.0 (unreachable) in both the gin binding and
the launch_slurm.sh fallback so a streaming-train-eval run trains through
all windows by default instead of early-stopping; set
$AUC_THRESHOLD=0.80275 for the MLPerf convergence target.

Co-authored-by: Cursor <cursoragent@cursor.com>
Restore NCCL_NET_GDR_LEVEL=5 + NCCL_DMABUF_ENABLE=1 defaults so RCCL does
true GPU<->NIC DMA over bnxt_re instead of host-memory staging (~+22%
throughput at 2 nodes; 65.7%->79.8% weak-scaling efficiency). The brcmrdma
host kernel ships the inbox peer-memory client, so GDR works with no
container/host changes; non-fatal fallback to host staging if peermem is
absent. Override with NCCL_NET_GDR_LEVEL=0.

Co-authored-by: Cursor <cursoragent@cursor.com>
The slimmed launch_slurm.sh has smoke-shaped run defaults (START_TS=150,
NUM_TRAIN_TS=1, NUM_TRAIN_BATCHES=20, per-window eval) and no SMOKE=1
toggle, so a bare submit is a short functional run — not the 299-window
reference. Document the bare submit as the smoke run and give the explicit
env-override command for the full reference sweep; drop the unimplemented
SMOKE=1 instructions.

Co-authored-by: Cursor <cursoragent@cursor.com>
dlrmv4: portable multi-node baseline + MLPerf compliance logging
Bring the HSTU/yambda-5b benchmark in line with MLPerf Training reference
conventions:
- add download_dataset.sh / verify_dataset.sh / run_and_time.sh wrappers
- add md5sums checksum file (placeholder hashes) for dataset verification
- restructure README.MD to the MLPerf spec (summaries, model+paper,
  hyperparameter table with tuning rules, quality target, eval frequency,
  steps-to-run)
- freeze requirements.txt to the exact Dockerfile/training_recipe versions
- add blank RCP placeholder (rcp/README.md) to be filled once convergence
  runs are generated

Co-authored-by: Cursor <cursoragent@cursor.com>
recommendation_v4: add MLPerf reference scripts, structure, and docs
@github-actions

Copy link
Copy Markdown

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@chriscai-amd chriscai-amd changed the title Chcai/dlrmv4 idfix [WIP] Adding DLRMv4 HSTU Jun 25, 2026
… configs

Remove subtrees unused by the yambda-5b TRITON training path:
- dlrm_v3/inference/ (incl thirdparty/loadgen)
- ops/cpp/ (CUTLASS CUDA kernels) and ops/triton_aot/ (AOT-inference kernels)
- generative_recommenders/research/ and its entrypoints (main.py,
  run_fractal_expansion.py, repo-root preprocess_public_data.py)
- configs/{ml-1m,ml-20m,ml-3b,amzn-books} and non-yambda train gins
  (keep yambda_5b.gin + debug.gin)
- ops/benchmarks/hstu_attention_bench.py (dangling ops.cpp import)

The TRITON (default) and PYTORCH kernel paths are unaffected: aot_* calls are
gated behind HammerKernel.TRITON_INFERENCE and ops/triton/* has no cpp/aot deps.
Validated by import smoke + a streaming-train-eval e2e smoke (rc=0).

Co-authored-by: Cursor <cursoragent@cursor.com>
@chriscai-amd chriscai-amd changed the title [WIP] Adding DLRMv4 HSTU [WIP][MLPerf 6.1][DLRMv4] Adding DLRMv4 HSTU Jun 25, 2026
chriscai-amd and others added 12 commits June 26, 2026 05:05
Change the default eval cadence from per-window (EVAL_EVERY_N_WINDOWS=1)
to data-fraction every 0.5% of data (EVAL_EVERY_N_WINDOWS=0,
EVAL_EVERY_DATA_PCT=0.005). Per-window spacing is uneven in data volume
since each daily window holds a different number of samples; the
data-fraction cadence yields ~200 evenly-spaced-by-compute eval points.

Updates the gin defaults and the launch_slurm.sh / launch_local.sh
fallbacks together so the two cadences are never both >0 (which raises a
ValueError at startup), and corrects the corresponding comments.

Co-authored-by: Cursor <cursoragent@cursor.com>
… e2e test

- broadcast total_train_anchors from rank-0 (avoid redundant mmap-gather +
  UID-hash recompute on every rank) and add a window-boundary dist.barrier()
  to prevent NCCL collective deadlock on skewed per-rank data prep.
- generalize streaming_resume_test.sh with --platform auto-detect for both
  NVIDIA B200 and AMD MI350/355 (container names, dataset paths, ckpt roots,
  node-local data staging), adding midwindow + multiwindow scenarios.
- extend streaming_resume_test.py with a `summarize` subcommand that parses
  logs for anchor-broadcast, window-barrier, eval-trigger and resume signals.
- env-gated barrier debug log in utils.py for test observability.

Validated end-to-end on B200: both midwindow and multiwindow scenarios PASS.

Co-authored-by: Cursor <cursoragent@cursor.com>
Platform-aware per-phase timeouts (meta64 NFS full-model checkpoints take
~9 min each vs B200 node-local NVMe), exposed via new --phase-timeout /
--mw-run-timeout overrides. Fix cleanup_workers self-kill: a plain
`pkill -f generative_recommenders` matched its own shell and SIGKILLed
cleanup mid-run, leaking trainer VRAM so the next phase OOM'd; now uses
bracketed patterns and blocks until trainers exit and VRAM drains.

Validated PASS end-to-end on MI350 (midwindow + multiwindow).

Co-authored-by: Cursor <cursoragent@cursor.com>
Add a plain-language header section to streaming_resume_test.sh explaining
why there are two scenarios and how they differ: midwindow guards resume
CORRECTNESS (land on the right batch/RNG/checkpoint within one window), while
multiwindow guards LIVENESS at window seams (all ranks cross a window boundary
in lockstep without an NCCL desync deadlock). Includes before/after timelines
of the boundary hang vs the rank-0-broadcast + dist.barrier fixes, why a
single-window test structurally cannot catch those bugs, and a comparison
table. Comments only; no behavior change.

Co-authored-by: Cursor <cursoragent@cursor.com>
dlrmv3 streaming: fix distributed sync + generalize checkpoint/resume…
Add per-table and global control over embedding table placement via gin,
overridable by env var. make_optimizer_and_shard now translates an
EMB_PLACEMENT global default plus per-table EMB_PLACEMENT_OVERRIDES
(hbm|uvm|uvm_caching|auto) into torchrec ParameterConstraints fed to the
EmbeddingShardingPlanner; "auto" leaves the table to the planner so the
default is byte-identical to the prior behavior (constraints=None). A new
env_str_map gin helper parses "name=val,name=val" with opt-in per-key merge
so a launch-time env tweak layers over the gin default. Also logs the
planner's ACTUAL per-table compute kernel ([emb-placement] plan: ...).

Validated end-to-end on 8x B200: force-HBM put every table on fused, and a
per-table override put uid on fused_uvm_caching while the rest stayed fused;
both trained cleanly.

Co-authored-by: Cursor <cursoragent@cursor.com>
When embedding all-to-all quantization is configured (SPARSE_A2A_FWD/BWD)
but cannot actually be enabled, _maybe_apply_qcomm_a2a previously logged a
warning and returned the unquantized sharders, silently running fp32. That
hides real misconfiguration and, worse, could leave some ranks on fp32 while
others run fp16 — desyncing the embedding collectives.

Now every "configured but not enabled" path raises (on all ranks, so the job
aborts consistently):
  - unknown precision string -> ValueError
  - codec registry build failure -> RuntimeError (chained from cause)
  - codec built but no EmbeddingCollectionSharder to bind it to -> RuntimeError

The legitimate no-quant default (forward=backward=fp32) still returns the
sharders untouched.
Flip SPARSE_A2A_FWD/BWD gin defaults from fp32 to fp16. An A/B run vs the fp32
interleaved baseline matched window AUC to within fixed-seed noise (mean
Δ≈-1e-6, max|Δ|≈5e-5, Pearson r=1.0 over the 0-54.5% data overlap) while
halving the embedding all-to-all wire volume (~6% end-to-end speedup). Grads on
yambda-5b stay well inside fp16 range (grad-clip=1.0, lr=1e-7), so fp16 backward
is convergence-neutral here; set both to "fp32" to restore the unquantized path.
Align the gin defaults so a no-override streaming-train-eval launch reproduces
the validated fp16 run:
  - START_TS 150 -> 0 and NUM_TRAIN_TS 149 -> 299 (sweep the full ts=0..298
    corpus instead of the dense ts=150..298 sub-range)
  - EVAL_HOLDOUT_TS -1 -> 299 (the window just past training; equivalent to the
    prior runtime-resolved value under the new sweep, now explicit)
  - CKPT_TIME_INTERVAL_S 0.0 -> 3600 (hourly saves)
Comments updated; each remains env-overridable to restore the old behavior.
…ch OOM hang

The embedding all-to-all is row-wise sharded, so a data-skewed batch routes a
few extremely hot IDs to a single owner rank, ballooning its a2a input tensor.
fbgemm's quant codec packs the payload via `torch.clamp(t, MIN, MAX).half()`,
where clamp() allocates a full-size fp32 temp before the cast (~2.5x input
peak). On the hottest shard that temp reached ~81.5 GiB and OOM'd the rank,
which then dropped out of the collective while peers blocked in the a2a -> a
deterministic ~30-min NCCL-watchdog hang (yambda-5b 4-node fp16, window 235 /
global step 43621, every run).

Fix: monkeypatch the codec to cast first then clamp in place
(`t.half().clamp_(MIN, MAX)`), dropping the full-size fp32 temp. Bit-for-bit
identical output (values above HALF_MAX cast to +inf which clamp_ maps back to
HALF_MAX; NaNs unchanged) and no throughput regression (strictly less memory
traffic). Validated: the patched run trains through step 43621 with 0 OOM /
0 watchdog timeouts and an unperturbed window-AUC trajectory.

Gin-configurable via make_optimizer_and_shard.qcomm_lowmem_clamp_cast
($QCOMM_LOWMEM_CODEC), ON by default, under a new RUNTIME PATCHES section in
yambda_5b.gin documenting the rationale; a no-op when the a2a is unquantized.

launch_slurm.sh: forward PG_TIMEOUT_S + TORCH_NCCL_* flight-recorder env into
the container (the instrumentation used to root-cause this hang).

Co-authored-by: Cursor <cursoragent@cursor.com>
Completes the qcomm low-memory codec knob: without this the gin
qlcc/env_int($QCOMM_LOWMEM_CODEC) binding could never see an env override inside
the container (it only saw the gin default). Now $QCOMM_LOWMEM_CODEC set at
submit time reaches the trainer, so the patch can be toggled per-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
…pe overrides

Add EMB_SHARDING_OVERRIDES (gin/env) so individual embedding tables can be
pinned to a sharding type, orthogonal to the existing placement override.
Default OFF -> plan is byte-identical to the legacy all-ROW_WISE path.

Motivation: ROW_WISE routes every lookup of a hot ID to its single owner rank,
so a few popular albums/artists concentrate the embedding all-to-all onto one
rank; the burst scales ~linearly with global batch size and OOM'd the hot rank
(~208-238 GiB / 288) at window ~248 on the yambda-5b 4-node run. Moving
album_id/artist_id to COLUMN_WISE balances the a2a by rank regardless of hot-ID
skew. Validated by a reshard smoke: DCP loaded the ROW_WISE ckpt into the CW
plan cleanly, window_auc stayed ~0.78-0.80, and the hot rank sat at ~58%
(~120 GiB free) through the previously-OOM window.

- utils: _build_placement_constraints/make_optimizer_and_shard accept
  embedding_sharding_overrides and merge them into ParameterConstraints.
- gin: EMB_SHARDING_OVERRIDES env_str_map binding + rationale/example comments.
- launch_slurm: forward EMB_SHARDING_OVERRIDES (+ placement) into the container.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants